{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nH85BOCo7YYk"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "9tQNAByc7U9g"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nsh9bFNyHJWE"
      },
      "source": [
        "## Fine-Tuning Gemma for Retrieval-Augmented Generation with JORA\n",
        "\n",
        "Scaling Large Language Models (LLMs) for retrieval-based tasks, particularly Retrieval-Augmented Generation (RAG), poses significant memory challenges, especially when fine-tuning on long prompt sequences.\n",
        "\n",
        "[Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources, such as a laptop, desktop, or your own cloud infrastructure, democratizing access to state-of-the-art AI models and helping foster innovation for everyone.\n",
        "\n",
        "Existing open-source libraries support full-model inference and fine-tuning across multiple GPUs, but they often fall short at efficiently distributing the parameters required to handle retrieved context. To address this limitation, [JORA](https://github.com/aniquetahir/JORA) introduces a framework for Parameter-Efficient Fine-Tuning (PEFT) of Llama and Gemma models using distributed training in [JAX](https://jax.readthedocs.io/en/latest/). It combines JAX's just-in-time (JIT) compilation with tensor sharding to manage resources efficiently, which speeds up fine-tuning and reduces memory requirements. This makes fine-tuning LLMs for complex RAG applications more scalable and feasible, even on systems with limited GPU resources.\n",
        "\n",
        "In the JORA authors' experiments, this approach achieved more than a **12x improvement in runtime** over [Hugging Face](https://huggingface.co/docs/transformers/en/main_classes/trainer)/[DeepSpeed](https://github.com/microsoft/DeepSpeed) implementations on four GPUs, while using less than half the VRAM per GPU.\n",
        "\n",
        "In this tutorial, you'll walk through the end-to-end process of fine-tuning a [Gemma](https://github.com/google/gemma) model with JORA and converting the trained model to the [Hugging Face](https://huggingface.co/) format for inference.\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Finetune_with_JORA.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>\n",
        "<br><br>\n",
        "\n",
        "[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/notebooks/welcome?src=https://github.com/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Finetune_with_JORA.ipynb)"
      ]
    },
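    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the JIT and tensor-sharding idea concrete, here is a minimal, generic JAX sketch. It is not JORA's code and relies only on the public `jax.sharding` API. The weight matrix and the `'model'` mesh-axis name are hypothetical. The sketch splits a weight matrix column-wise across every visible device and runs a JIT-compiled matrix multiply. On a single device, the same code simply runs unsharded.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n",
        "\n",
        "# A 1-D device mesh over every visible device (a single CPU or GPU also works).\n",
        "devices = np.array(jax.devices())\n",
        "mesh = Mesh(devices, axis_names=('model',))\n",
        "n = len(devices)\n",
        "\n",
        "# Hypothetical weight matrix whose columns are split across the 'model' axis,\n",
        "# so each device stores only 1/n of the parameters.\n",
        "w = jax.device_put(jnp.ones((8, 4 * n)), NamedSharding(mesh, P(None, 'model')))\n",
        "\n",
        "@jax.jit\n",
        "def forward(x, w):\n",
        "    # Compiled once by XLA, which inserts any cross-device communication itself.\n",
        "    return x @ w\n",
        "\n",
        "x = jnp.ones((2, 8))\n",
        "y = forward(x, w)\n",
        "print(y.shape, y.sharding)\n",
        "```\n",
        "\n",
        "With four GPUs, each device would hold a quarter of `w`. This shows how sharding lowers the per-GPU memory footprint of the parameters."
      ]
    },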
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BHzVsHf-mCLi"
      },
      "source": [
        "## Setup\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p1TbFDxPoWGd"
      },
      "source": [
        "### Selecting the Runtime Environment\n",
        "\n",
        "To start, you can choose either **Google Colab** or **Kaggle** as your platform. Select one, and proceed from there.\n",
        "\n",
        "- #### **Google Colab** <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png\" alt=\"Google Colab\" width=\"30\"/>\n",
        "\n",
        "  1. Click **Open in Colab**.\n",
        "  2. You'll need access to a [**Colab Pro/Pro+**](https://colab.research.google.com/signup) runtime with sufficient resources to run the Gemma model.\n",
        "  3. In the menu, go to **Runtime** > **Change runtime type**.\n",
        "  4. Ensure that the **GPU** is set to **A100**.\n",
        "\n",
        "- #### **Kaggle** <img src=\"https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png\" alt=\"Kaggle\" width=\"40\"/>\n",
        "\n",
        "  1. Click **Open in Kaggle**.\n",
        "  2. Click on **Session options** in the right sidebar.\n",
        "  3. Under **Accelerator**, select **GPU T4 x2**.\n",
        "     - Note: This instance comes with **15 GB x2** (15 GB for each T4 GPU) of VRAM and **30 GB** of RAM.\n",
        "  4. Save the settings, and the notebook will restart with GPU support."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RFm2S0Gijqo8"
      },
      "source": [
        "### Gemma setup\n",
        "\n",
        "#### **Kaggle Models**\n",
        "\n",
        "To download the Kaggle Gemma Flax models and fine-tune them in this tutorial, you first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
        "\n",
        "* Get access to Gemma on kaggle.com.\n",
        "* Select a Colab/Kaggle runtime with sufficient resources to run\n",
        "  the Gemma model.\n",
        "* Generate a Kaggle username and API key. You'll store them as Colab or Kaggle secrets later in this guide.\n",
        "\n",
        "#### **Hugging Face Hub**\n",
        "\n",
        "You'll also log in to the Hugging Face Hub to download the same Gemma model in Hugging Face format. You need it to convert the fine-tuned Flax model to the Hugging Face format and run inference later. Let's get you set up:\n",
        "\n",
        "1. **Hugging Face Account:**  If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).\n",
        "2. **Gemma Model Access:** Head over to the [Gemma model page](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) and accept the usage conditions.\n",
        "3. **Colab/Kaggle Runtime:**  For this tutorial, you'll need a Colab/Kaggle runtime with enough resources to handle the Gemma model. Choose an appropriate runtime when starting your Colab/Kaggle session.\n",
        "4. **Hugging Face Token:**  Generate a Hugging Face access token (preferably with `write` permission) by clicking [here](https://huggingface.co/settings/tokens). You'll need this token later.\n",
        "\n",
        "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab or Kaggle environment."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CY2kGtsyYpHF"
      },
      "source": [
        "### Configure Your Credentials\n",
        "\n",
        "To access private models and datasets, you need to log in to both the Hugging Face (HF) and Kaggle ecosystems.\n",
        "\n",
        "- #### **Google Colab** <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png\" alt=\"Google Colab\" width=\"30\"/>\n",
        "  If you're using Colab, you can securely store your Hugging Face token (`HF_TOKEN`) using the Colab Secrets manager:\n",
        "  1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "  2. **Add Hugging Face Token**:\n",
        "    - Create a new secret with the **name** `HF_TOKEN`.\n",
        "    - Copy/paste your token key into the **Value** input box of `HF_TOKEN`.\n",
        "    - **Toggle** the button on the left to allow notebook access to the secret.\n",
        "  3. **Add Kaggle Token**:\n",
        "    - Repeat the same steps for `KAGGLE_USERNAME` and `KAGGLE_KEY`.\n",
        "\n",
        "\n",
        "- #### **Kaggle** <img src=\"https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png\" alt=\"Kaggle\" width=\"40\"/>\n",
        "  To securely use your Hugging Face token (`HF_TOKEN`) and Kaggle credentials in this notebook, add them as secrets in your Kaggle environment:  \n",
        "  1. Open your Kaggle notebook and locate the **Addons** menu at the top in your notebook interface.\n",
        "  2. Click on **Secrets** to manage your environment secrets.  \n",
        "  <img src=\"https://i.imgur.com/vxrtJuM.png\" alt=\"The Secrets option is found at the top.\" width=50%>\n",
        "  3. **Add Hugging Face Token**:\n",
        "      - Click on the **Add secret** button.\n",
        "      - In the **Label** field, enter `HF_TOKEN`.  \n",
        "      - In the **Value** field, paste your Hugging Face token.\n",
        "      - Click **Save** to add the secret.\n",
        "  4. **Add Kaggle Token**:\n",
        "      - Repeat the same steps for `KAGGLE_USERNAME` and `KAGGLE_KEY`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7-1PYEuJuJyN"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "if 'google.colab' in sys.modules:\n",
        "    from google.colab import userdata\n",
        "    # Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "    # vars as appropriate for your system.\n",
        "    os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")\n",
        "    os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "    os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
        "elif os.path.exists('/kaggle/working'):\n",
        "    from kaggle_secrets import UserSecretsClient\n",
        "    user_secrets = UserSecretsClient()\n",
        "    os.environ['HF_TOKEN'] = user_secrets.get_secret(\"HF_TOKEN\")\n",
        "    os.environ[\"KAGGLE_USERNAME\"] = user_secrets.get_secret('KAGGLE_USERNAME')\n",
        "    os.environ[\"KAGGLE_KEY\"] = user_secrets.get_secret('KAGGLE_KEY')\n",
        "else:\n",
        "    raise RuntimeError(\n",
        "        \"Unsupported runtime environment detected.\\n\"\n",
        "        \"This notebook currently supports execution on Google Colab or Kaggle.\\n\"\n",
        "        \"Please ensure you are running in one of these environments.\\n\"\n",
        "        \"If you are running locally or on a different platform, manually set the following environment variables:\\n\"\n",
        "        \" - HF_TOKEN\\n\"\n",
        "        \" - KAGGLE_USERNAME\\n\"\n",
        "        \" - KAGGLE_KEY\\n\\n\"\n",
        "        \"You can set environment variables in your terminal or within your Python notebook before running any cells.\"\n",
        "    )\n",
        "\n",
        "# Disable progress bar to prevent verbose logging by kagglehub\n",
        "os.environ[\"TQDM_DISABLE\"] = \"1\""
      ]
    },
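    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optionally, you can fail fast if a credential is missing. The hypothetical `missing_credentials` helper below is not part of JORA or `kagglehub`. It only checks that the three variables set in the previous cell are present and non-empty.\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "REQUIRED_VARS = ('HF_TOKEN', 'KAGGLE_USERNAME', 'KAGGLE_KEY')\n",
        "\n",
        "def missing_credentials(env=os.environ):\n",
        "    # Names of required variables that are unset or set to an empty string.\n",
        "    return [name for name in REQUIRED_VARS if not env.get(name)]\n",
        "\n",
        "missing = missing_credentials()\n",
        "if missing:\n",
        "    print('Missing credentials:', ', '.join(missing))\n",
        "else:\n",
        "    print('All credentials are set.')\n",
        "```"
      ]
    },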
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J9Pj2Y9EHJWF"
      },
      "source": [
        "### Clone **JORA** and install dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ryPIut33HJWF"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Cloning into 'JORA'...\n",
            "remote: Enumerating objects: 299, done.\u001b[K\n",
            "remote: Counting objects: 100% (299/299), done.\u001b[K\n",
            "remote: Compressing objects: 100% (216/216), done.\u001b[K\n",
            "remote: Total 299 (delta 151), reused 203 (delta 71), pack-reused 0 (from 0)\u001b[K\n",
            "Receiving objects: 100% (299/299), 6.99 MiB | 17.66 MiB/s, done.\n",
            "Resolving deltas: 100% (151/151), done.\n",
            "/content/JORA\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m87.2/87.2 kB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.1/57.1 MB\u001b[0m \u001b[31m42.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m320.1/320.1 kB\u001b[0m \u001b[31m30.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.9/94.9 kB\u001b[0m \u001b[31m9.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.1/11.1 MB\u001b[0m \u001b[31m122.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m73.2/73.2 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.8/63.8 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m130.2/130.2 kB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Building wheel for fire (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for gemma (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "Looking in links: https://storage.googleapis.com/jax-releases/jax_nightly_releases.html\n",
            "Requirement already satisfied: jax==0.4.33 in /usr/local/lib/python3.10/dist-packages (0.4.33)\n",
            "Requirement already satisfied: jaxlib==0.4.33 in /usr/local/lib/python3.10/dist-packages (0.4.33)\n",
            "Requirement already satisfied: jax-cuda12-plugin==0.4.33 in /usr/local/lib/python3.10/dist-packages (from jax-cuda12-plugin[with_cuda]==0.4.33) (0.4.33)\n",
            "Requirement already satisfied: jax-cuda12-pjrt==0.4.33 in /usr/local/lib/python3.10/dist-packages (0.4.33)\n",
            "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax==0.4.33) (0.4.1)\n",
            "Requirement already satisfied: numpy>=1.24 in /usr/local/lib/python3.10/dist-packages (from jax==0.4.33) (1.26.4)\n",
            "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax==0.4.33) (3.4.0)\n",
            "Requirement already satisfied: scipy>=1.10 in /usr/local/lib/python3.10/dist-packages (from jax==0.4.33) (1.13.1)\n",
            "Requirement already satisfied: nvidia-cublas-cu12>=12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from jax-cuda12-plugin[with_cuda]==0.4.33) (12.6.3.3)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12>=12.1.105 in /usr/local/lib/python3.10/dist-packages (from jax-cuda12-plugin[with_cuda]==0.4.33) (12.6.80)\n",
            "Requirement already satisfied: nvidia-cuda-nvcc-cu12>=12.1.105 in /usr/local/lib/python3.10/dist-packages (from jax-cuda12-plugin[with_cuda]==0.4.33) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12>=12.1.105 in /usr/local/lib/python3.10/dist-packages (from jax-cuda12-plugin[with_cuda]==0.4.33) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cudnn-cu12<10.0,>=9.1 in /usr/local/lib/python3.10/dist-packages (from jax-cuda12-plugin[with_cuda]==0.4.33) (9.5.1.17)\n",
            "Requirement already satisfied: nvidia-cufft-cu12>=11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from jax-cuda12-plugin[with_cuda]==0.4.33) (11.3.0.4)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12>=11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from jax-cuda12-plugin[with_cuda]==0.4.33) (11.7.1.2)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12>=12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from jax-cuda12-plugin[with_cuda]==0.4.33) (12.5.4.2)\n",
            "Requirement already satisfied: nvidia-nccl-cu12>=2.18.1 in /usr/local/lib/python3.10/dist-packages (from jax-cuda12-plugin[with_cuda]==0.4.33) (2.23.4)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12>=12.1.105 in /usr/local/lib/python3.10/dist-packages (from jax-cuda12-plugin[with_cuda]==0.4.33) (12.6.77)\n"
          ]
        }
      ],
      "source": [
        "# Clone the JORA repository and install the requirements\n",
        "!git clone https://github.com/aniquetahir/JORA.git\n",
        "%cd JORA\n",
        "!pip install -q -e .\n",
        "\n",
        "# Install google-deepmind/gemma as it's a required dependency for JORA\n",
        "!pip install -q git+https://github.com/google-deepmind/gemma.git\n",
        "\n",
        "# Install the appropriate JAX version\n",
        "JAX_VERSION = \"0.4.33\"\n",
        "!pip install -U --pre -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \\\n",
        "  jax==$JAX_VERSION jaxlib==$JAX_VERSION \\\n",
        "  jax-cuda12-plugin[with_cuda]==$JAX_VERSION jax-cuda12-pjrt==$JAX_VERSION"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UMa4gjjvM_8N"
      },
      "source": [
        "### Import the dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FHfyqXc7SuBc"
      },
      "outputs": [],
      "source": [
        "# Patch JORA's initialisation.py to be compatible with the JAX version installed above:\n",
        "# replace the `jax.lax.Precision.HIGHEST` matmul precision setting with 'bfloat16'.\n",
        "\n",
        "!sed -i \"s/jax\\.config\\.update('jax_default_matmul_precision', *jax\\.lax\\.Precision\\.HIGHEST)/jax.config.update('jax_default_matmul_precision', 'bfloat16')/\" jora/lib/proc_init_utils/initialisation.py"
      ]
    },
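    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `sed` one-liner above is dense, so here is the same substitution as a Python sketch that applies `re.sub` to a sample line. The sample line is an assumption reconstructed from the `sed` pattern. Check `jora/lib/proc_init_utils/initialisation.py` for the exact original.\n",
        "\n",
        "```python\n",
        "import re\n",
        "\n",
        "# Sample line, reconstructed from the sed pattern; the real file may differ slightly.\n",
        "line = \"jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)\"\n",
        "\n",
        "# Python equivalent of the sed pattern. Unlike sed's basic regular expressions,\n",
        "# Python's `re` needs the literal parentheses escaped; ' *' allows optional spaces.\n",
        "pattern = r\"jax\\.config\\.update\\('jax_default_matmul_precision', *jax\\.lax\\.Precision\\.HIGHEST\\)\"\n",
        "replacement = \"jax.config.update('jax_default_matmul_precision', 'bfloat16')\"\n",
        "\n",
        "patched = re.sub(pattern, replacement, line)\n",
        "print(patched)\n",
        "```"
      ]
    },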
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8einQYITNKxR"
      },
      "outputs": [],
      "source": [
        "import kagglehub\n",
        "import jax\n",
        "import jora\n",
        "import pathlib\n",
        "import torch\n",
        "\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "from huggingface_hub import snapshot_download"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7gEJeEXWHJWG"
      },
      "source": [
        "## Download the Gemma Model\n",
        "\n",
        "Now, you can download the Gemma model using `kagglehub`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QpqxVCRmjIeY"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "415b8964c1b04f24856fad5e6bffabe2",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading 11 files:   0%|          | 0/11 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/Flax/gemma2-2b-it/1/download/gemma2-2b-it/_METADATA...\n",
            "Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/Flax/gemma2-2b-it/1/download/gemma2-2b-it/_CHECKPOINT_METADATA...\n",
            "Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/Flax/gemma2-2b-it/1/download/gemma2-2b-it/ocdbt.process_0/d/bf69258061ae5f35eb7a5669fe6877d4...\n",
            "Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/Flax/gemma2-2b-it/1/download/gemma2-2b-it/d/b5a4695f4be0a2f41ec1e25616ebd7e7...\n",
            "Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/Flax/gemma2-2b-it/1/download/gemma2-2b-it/ocdbt.process_0/d/834bb4bf1e3854eb09f6208c95c071b2...\n",
            "Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/Flax/gemma2-2b-it/1/download/gemma2-2b-it/descriptor/descriptor.pbtxt...\n",
            "Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/Flax/gemma2-2b-it/1/download/gemma2-2b-it/ocdbt.process_0/manifest.ocdbt...\n",
            "Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/Flax/gemma2-2b-it/1/download/gemma2-2b-it/ocdbt.process_0/d/fc20151969d7ca91ea9d8275bda0e219...\n",
            "Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/Flax/gemma2-2b-it/1/download/gemma2-2b-it/manifest.ocdbt...\n",
            "Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/Flax/gemma2-2b-it/1/download/tokenizer.model...\n",
            "Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2/Flax/gemma2-2b-it/1/download/gemma2-2b-it/checkpoint...\n",
            "GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma-2/Flax/gemma2-2b-it/1\n"
          ]
        }
      ],
      "source": [
        "VARIANT = \"gemma2-2b-it\"\n",
        "GEMMA_PATH = kagglehub.model_download(f'google/gemma-2/Flax/{VARIANT}')\n",
        "print('GEMMA_PATH:', GEMMA_PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ponZxNy3sfwS"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{'7b-it', '2b-it', 'gemma2-2b-it', '7b', '2b'}\n",
            "{'2b': GemmaConfig(n_heads=8, n_kv=1), '2b-it': GemmaConfig(n_heads=8, n_kv=1), '7b': GemmaConfig(n_heads=16, n_kv=16), '7b-it': GemmaConfig(n_heads=16, n_kv=16), '1.1-2b-it': GemmaConfig(n_heads=8, n_kv=1), '1.1-7b-it': GemmaConfig(n_heads=16, n_kv=16), 'gemma2-2b-it': GemmaConfig(n_heads=8, n_kv=1)}\n"
          ]
        }
      ],
      "source": [
        "# Note: JORA currently only supports loading Gemma and Gemma 1.1 models.\n",
        "# Register `gemma2-2b-it` so that JORA can discover the Gemma 2 model.\n",
        "\n",
        "# `set.add` and `dict.update` modify JORA's module-level objects in place and\n",
        "# return None, so their results must not be assigned back to the variables.\n",
        "JORA_GEMMA_VERSIONS = jora.lib.gemma.gemma_config.GEMMA_VERSIONS\n",
        "JORA_GEMMA_VERSIONS.add('gemma2-2b-it')\n",
        "print(JORA_GEMMA_VERSIONS)\n",
        "\n",
        "JORA_GEMMA_MODEL_MAPPING = jora.lib.gemma.common.model_config_mapping\n",
        "JORA_GEMMA_MODEL_MAPPING.update({\n",
        "    'gemma2-2b-it': jora.lib.gemma.gemma_config.GemmaConfig2B\n",
        "})\n",
        "print(JORA_GEMMA_MODEL_MAPPING)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KnB1B5TxHJWH"
      },
      "source": [
        "**Note:** By default, `kagglehub` stores the model in the `~/.cache/kagglehub` directory.\n",
        "\n",
        "Verify that JAX recognizes the GPU devices:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FsLysmfPHJWH"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[CudaDevice(id=0)]\n"
          ]
        }
      ],
      "source": [
        "print(jax.devices())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UB8Xxe9RHJWH"
      },
      "source": [
        "## Configure JORA and Prepare the Dataset\n",
        "\n",
        "Here, you'll configure the Gemma model and the training process for **LoRA** fine-tuning.\n",
        "\n",
        "In order to fine-tune Gemma, you will use the **Alpaca** dataset. Ensure you have the dataset file `alpaca_data_cleaned.json` in the appropriate directory. You can download it from [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data_cleaned.json) or use the one that's bundled in the repository. For demonstration purposes, let's use the bundled one.\n",
        "\n",
        "**Credits:** [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json)\n",
        "\n",
        "The `generate_alpaca_dataset_gemma` function generates the dataset from an Alpaca-format JSON file. This simplifies instruction-format training because the library handles dataset processing, tokenization, and batching. Alternatively, you can use a PyTorch `Dataset` and `DataLoader` for custom datasets.\n"
      ]
    },
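    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, an Alpaca-format file is a JSON list of records with `instruction`, `input` (often empty), and `output` fields. The sketch below turns such records into prompt/target pairs using Gemma's `<start_of_turn>` and `<end_of_turn>` chat markers. The template and the `to_gemma_example` helper are hypothetical and shown for illustration only. `generate_alpaca_dataset_gemma` may format and tokenize prompts differently.\n",
        "\n",
        "```python\n",
        "# Two records in the Alpaca format (normally loaded with json.load from alpaca_data_cleaned.json).\n",
        "records = [\n",
        "    {'instruction': 'Name a primary color.', 'input': '', 'output': 'Red.'},\n",
        "    {'instruction': 'Translate to French.', 'input': 'Good morning', 'output': 'Bonjour'},\n",
        "]\n",
        "\n",
        "def to_gemma_example(record):\n",
        "    # Hypothetical template: the instruction and optional input form one user turn.\n",
        "    user_text = record['instruction']\n",
        "    if record['input']:\n",
        "        user_text += '\\n\\n' + record['input']\n",
        "    prompt = f'<start_of_turn>user\\n{user_text}<end_of_turn>\\n<start_of_turn>model\\n'\n",
        "    # The training target is the reference output, closed by <end_of_turn>.\n",
        "    target = record['output'] + '<end_of_turn>'\n",
        "    return prompt, target\n",
        "\n",
        "for prompt, target in map(to_gemma_example, records):\n",
        "    print(repr(prompt), '->', repr(target))\n",
        "```"
      ]
    },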
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XyWf7EVpHJWH"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Processing data...\n"
          ]
        }
      ],
      "source": [
        "# Configure the model and training parameters\n",
        "config = jora.ParagemmaConfig(\n",
        "    # Feel free to tweak these parameters\n",
        "    N_EPOCHS=1,\n",
        "    LORA_R=8,\n",
        "    # Note: The `LORA_DROPOUT` parameter is currently not configurable.\n",
        "    # https://github.com/aniquetahir/JORA?tab=readme-ov-file#contributing\n",
        "    LORA_ALPHA=16,\n",
        "    LR=1e-5,\n",
        "    BATCH_SIZE=2,\n",
        "    N_ACCUMULATION_STEPS=8,\n",
        "    GEMMA_MODEL_PATH=GEMMA_PATH,\n",
        "    MAX_SEQ_LEN=512,\n",
        "    MODEL_VERSION=VARIANT\n",
        ")\n",
        "\n",
        "# Path to the Alpaca dataset\n",
        "dataset_path = 'jora/alpaca_data_cleaned.json'\n",
        "\n",
        "# Generate the dataset from a 20% split for prototyping.\n",
        "# On Kaggle, set `split_percentage` to 0.005 to use a smaller subset\n",
        "# for a quicker demonstration.\n",
        "dataset = jora.generate_alpaca_dataset_gemma(\n",
        "    dataset_path, 'train', config,\n",
        "    split_percentage=0.2,  # Use 0.005 on Kaggle\n",
        "    alpaca_mix=0.3\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dMDYXOKKHJWH"
      },
      "source": [
        "The `ParagemmaConfig` class is used to set up the configuration for training while `generate_alpaca_dataset_gemma` processes the dataset, handles tokenization, and prepares it for training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q0_iDvXsIRSx"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "ParagemmaConfig(GEMMA_MODEL_PATH='/root/.cache/kagglehub/models/google/gemma-2/Flax/gemma2-2b-it/1', MODEL_VERSION='gemma2-2b-it', NUM_SHARDS=None, LORA_R=8, LORA_ALPHA=16, LORA_DROPOUT=0.05, LR=1e-05, BATCH_SIZE=2, N_ACCUMULATION_STEPS=8, MAX_SEQ_LEN=512, N_EPOCHS=1, SEED=420, CACHE_SIZE=30)"
            ]
          },
          "execution_count": 10,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "config"
      ]
    },
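    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a back-of-the-envelope check on these settings, the sketch below assumes JORA follows the usual conventions. Gradients are accumulated over `N_ACCUMULATION_STEPS` micro-batches of `BATCH_SIZE` examples, and the LoRA update is scaled by `LORA_ALPHA / LORA_R`. This notebook doesn't state whether `BATCH_SIZE` is per device or global. The 2048 x 2048 matrix is a hypothetical size chosen only to illustrate the parameter savings.\n",
        "\n",
        "```python\n",
        "BATCH_SIZE = 2\n",
        "N_ACCUMULATION_STEPS = 8\n",
        "LORA_R = 8\n",
        "LORA_ALPHA = 16\n",
        "\n",
        "# With gradient accumulation, one optimizer update sees this many examples.\n",
        "effective_batch_size = BATCH_SIZE * N_ACCUMULATION_STEPS\n",
        "\n",
        "# Standard LoRA multiplies the low-rank update by alpha / r.\n",
        "lora_scale = LORA_ALPHA / LORA_R\n",
        "\n",
        "# Trainable parameters for one hypothetical 2048 x 2048 projection matrix.\n",
        "d_in, d_out = 2048, 2048\n",
        "full_params = d_in * d_out\n",
        "lora_params = LORA_R * (d_in + d_out)\n",
        "\n",
        "print(effective_batch_size, lora_scale)                    # 16 2.0\n",
        "print(full_params, lora_params, lora_params / full_params)  # 4194304 32768 0.0078125\n",
        "```\n",
        "\n",
        "Under these assumptions, each optimizer update accumulates gradients over 16 examples. In this example, LoRA also trains under 1% of the parameters of each adapted matrix."
      ]
    },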
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3FguwkPZHJWH"
      },
      "source": [
        "## Fine-tune Gemma with **JORA**\n",
        "\n",
        "Now you can fine-tune the model with the `train_lora_gemma` function, which runs LoRA (Low-Rank Adaptation) fine-tuning. Checkpoints are saved to the folder specified by `checkpoint_path`."
      ]
    },
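    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a refresher on what LoRA fine-tuning optimizes, the sketch below is a generic LoRA linear layer in JAX. The base weight `w` stays frozen, and only the low-rank factors `a` and `b` receive gradients. It follows the formulation from the LoRA paper (zero-initialized `b`, update scaled by `alpha / r`) and is not JORA's internal implementation. The layer sizes and the mean-squared-error loss are arbitrary choices for illustration.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def init_lora(key, d_in, d_out, r):\n",
        "    # a gets small random values and b starts at zero, so the adapted layer\n",
        "    # initially computes exactly the same output as the frozen base layer.\n",
        "    a = jax.random.normal(key, (d_in, r)) * 0.01\n",
        "    b = jnp.zeros((r, d_out))\n",
        "    return a, b\n",
        "\n",
        "@jax.jit\n",
        "def lora_linear(x, w, a, b, scale):\n",
        "    # Frozen base projection plus the scaled low-rank update.\n",
        "    return x @ w + scale * ((x @ a) @ b)\n",
        "\n",
        "def loss_fn(lora_params, x, w, y, scale):\n",
        "    a, b = lora_params\n",
        "    return jnp.mean((lora_linear(x, w, a, b, scale) - y) ** 2)\n",
        "\n",
        "key_w, key_x, key_lora = jax.random.split(jax.random.PRNGKey(0), 3)\n",
        "d_in, d_out, r, alpha = 16, 16, 8, 16\n",
        "w = jax.random.normal(key_w, (d_in, d_out))\n",
        "x = jax.random.normal(key_x, (4, d_in))\n",
        "a, b = init_lora(key_lora, d_in, d_out, r)\n",
        "\n",
        "# Differentiate only with respect to (a, b); w is treated as a constant.\n",
        "grads = jax.grad(loss_fn)((a, b), x, w, jnp.zeros((4, d_out)), alpha / r)\n",
        "print(jnp.allclose(lora_linear(x, w, a, b, alpha / r), x @ w))\n",
        "```\n",
        "\n",
        "Because `b` starts at zero, the adapted layer initially matches the base layer exactly, and the gradient for `a` is zero on the first step."
      ]
    },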
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xwr5ObmKHJWH"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Successfully loaded and sharded model parameters!\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "bd70d0ae19ff4e999956494f432803d2",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Output()"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
            ],
            "text/plain": []
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Directory where the trained LoRA checkpoints will be saved\n",
        "checkpoint_path = 'checkpoints'\n",
        "jora.train_lora_gemma(config, dataset, checkpoint_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u3D9iMEfHJWH"
      },
      "source": [
        "**Note**: Fine-tuning on the entire dataset can be time-consuming and may exceed available GPU quotas on **Kaggle** or consume significant compute units on **Google Colab**. Using a smaller split helps in managing resource usage and staying within platform-imposed limits."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FcbH1MCNHJWH"
      },
      "source": [
        "## Convert the model to the **Hugging Face Format**\n",
        "\n",
        "After fine-tuning, convert the trained model to the Hugging Face format so that it works with the Hugging Face ecosystem and you can easily run inference later.\n",
        "\n",
        "**Usage:**\n",
        "\n",
        "```python\n",
        "lorize_huggingface(HUGGINGFACE_PATH, JAX_PATH, SAVE_PATH, gemma=True)\n",
        "```\n",
        "\n",
        "- **HUGGINGFACE_PATH**: Path to the Hugging Face Gemma model (the base model before fine-tuning).\n",
        "- **JAX_PATH**: Path to the trained LoRA weights that JORA saved during fine-tuning.\n",
        "- **SAVE_PATH**: Path to save the fine-tuned Hugging Face Gemma model.\n",
        "- **gemma**: Flag indicating you're working with a Gemma model.\n",
        "\n",
        "First, download the base model from the Hugging Face Hub and specify the paths:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E0TaWlbFHJWI"
      },
      "outputs": [],
      "source": [
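        "from huggingface_hub import snapshot_download\n",
        "\n",
        "# Download the base (pre-fine-tuning) Gemma 2 model from the Hugging Face Hub\n",
        "repo_id = \"google/gemma-2-2b-it\"\n",
        "local_dir = 'pretrained'\n",
        "\n",
        "snapshot_download(\n",
        "    repo_id=repo_id,\n",
        "    local_dir=local_dir,\n",
        "    revision=\"main\",\n",
        "    ignore_patterns=['*.gguf']  # skip GGUF (llama.cpp) files\n",
        ")\n",
        "\n",
        "HUGGINGFACE_PATH = local_dir  # base Hugging Face model\n",
        "JAX_PATH = 'checkpoints/jax_lora_final.pickle'  # trained LoRA weights from JORA\n",
        "SAVE_PATH = 'gemma-ft'  # output directory for the converted model"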
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4zDE8A6KHJWI"
      },
      "source": [
        "Then, run the converter:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rUpmZXMhUWNz"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "model loaded\n",
            "model saved to gemma-ft\n"
          ]
        }
      ],
      "source": [
        "from jora.hf.__main__ import lorize_huggingface\n",
        "\n",
        "lorize_huggingface(HUGGINGFACE_PATH, JAX_PATH, SAVE_PATH, gemma=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fimTt20qHJWI"
      },
      "source": [
        "- The `jora.hf` module converts the JAX-trained model back to the Hugging Face format.\n",
        "- It merges the LoRA weights with the original model parameters.\n",
        "- The converted model is saved in the specified `SAVE_PATH`."
      ]
    },
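    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jLoraMrgSk01"
      },
      "source": [
        "To illustrate what merging means, the next cell is a minimal NumPy sketch of folding a LoRA update into a frozen weight matrix. It assumes the standard LoRA parameterization $W' = W + \\frac{\\alpha}{r} BA$, with $A$ of shape $(r, d_{in})$ and $B$ of shape $(d_{out}, r)$. The `merge_lora` helper, the matrix sizes, and the scaling convention are illustrative assumptions, not JORA's actual implementation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jLoraMrgSk02"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical helper (not part of JORA): fold a LoRA update into a frozen weight.\n",
        "# Assumes W' = W + (alpha / rank) * B @ A, with A: (rank, d_in) and B: (d_out, rank).\n",
        "def merge_lora(w, a, b, alpha, rank):\n",
        "    return w + (alpha / rank) * (b @ a)\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "d_out, d_in, rank, alpha = 8, 6, 2, 16\n",
        "w = rng.normal(size=(d_out, d_in))\n",
        "a = rng.normal(size=(rank, d_in))\n",
        "b = rng.normal(size=(d_out, rank))\n",
        "merged = merge_lora(w, a, b, alpha, rank)\n",
        "\n",
        "# The merged weight gives the base output plus the scaled low-rank path,\n",
        "# so no separate LoRA matrices are needed at inference time.\n",
        "x = rng.normal(size=(d_in,))\n",
        "lora_output = w @ x + (alpha / rank) * (b @ (a @ x))\n",
        "print(merged.shape)  # (8, 6)\n",
        "print(np.allclose(merged @ x, lora_output))  # True"
      ]
    },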
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mH91P3gRHJWI"
      },
      "source": [
        "## Load the Model and Generate Text\n",
        "\n",
        "Next, load the converted model with the Hugging Face Transformers library. The tokenizer is loaded from the base model directory, since LoRA fine-tuning does not change it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "romkTs-LO7P6"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:accelerate.big_modeling:Some parameters are on the meta device because they were offloaded to the cpu.\n"
          ]
        }
      ],
      "source": [
        "tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_PATH)\n",
        "model = AutoModelForCausalLM.from_pretrained(SAVE_PATH, device_map=\"auto\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "acWcwfqIHJWI"
      },
      "source": [
        "Setting `device_map=\"auto\"` lets Transformers (through Accelerate) place the model weights on the available devices automatically, offloading some of them to the CPU if GPU memory is insufficient. Finally, generate text with the model using the Alpaca prompt format:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C32TYrMgHJWI"
      },
      "outputs": [],
      "source": [
        "# Define the Alpaca prompt template\n",
        "alpaca_prompt = \"\"\"\\\n",
        "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n",
        "\n",
        "### Instruction:\n",
        "{}\n",
        "\n",
        "### Input:\n",
        "{}\n",
        "\n",
        "### Response:\n",
        "\"\"\"\n",
        "\n",
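        "# Generate and print a response for an instruction and optional input\n",
        "def generate_response(instruction, input_text=\"\", max_new_tokens=384):\n",
        "    prompt = alpaca_prompt.format(instruction, input_text)\n",
        "    # Send inputs to the model's device instead of hard-coding \"cuda\",\n",
        "    # since device_map=\"auto\" decides where the weights are placed\n",
        "    inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(model.device)\n",
        "    outputs = model.generate(inputs, max_new_tokens=max_new_tokens)\n",
        "    # The decoded text contains the prompt followed by the generated response\n",
        "    text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
        "    print(text)"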
      ]
    },
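    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jRespExtr01"
      },
      "source": [
        "Because `generate_response` decodes the full output sequence, the printed text repeats the Alpaca prompt before the model's answer. If you only need the answer, one option is to split the decoded text on the `### Response:` marker. The next cell sketches this with a hypothetical `extract_response` helper that assumes the template above; another option is to decode only the newly generated tokens, for example `outputs[0][inputs.shape[-1]:]`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jRespExtr02"
      },
      "outputs": [],
      "source": [
        "# Hypothetical helper (not part of JORA or Transformers): keep only the text\n",
        "# generated after the last Alpaca '### Response:' marker.\n",
        "RESPONSE_MARKER = '### Response:'\n",
        "\n",
        "def extract_response(decoded_text):\n",
        "    # rpartition splits on the last occurrence of the marker; if the marker is\n",
        "    # absent, the whole text lands in the last element and is returned stripped.\n",
        "    _, _, response = decoded_text.rpartition(RESPONSE_MARKER)\n",
        "    return response.strip()\n",
        "\n",
        "sample_lines = [\n",
        "    '### Instruction:',\n",
        "    'Name a primary color.',\n",
        "    '',\n",
        "    '### Response:',\n",
        "    'Red',\n",
        "]\n",
        "print(extract_response('\\n'.join(sample_lines)))  # Red"
      ]
    },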
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CdkScc4Kx9l0"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n",
            "\n",
            "### Instruction:\n",
            "Identify 3 common mistakes in the following sentence. Suggest changes.\n",
            "\n",
            "### Input:\n",
            "She seems to believe that the real key to sucsess is working smart and hard.\n",
            "\n",
            "### Response:\n",
            "1. \"sucsess\" should be \"success\"\n",
            "2. \"seems to believe\" is a weak phrase.\n",
            "3. \"working smart and hard\" is a cliché.\n"
          ]
        }
      ],
      "source": [
        "generate_response(\n",
        "    instruction=\"Identify 3 common mistakes in the following sentence. Suggest changes.\",\n",
        "    input_text=\"She seems to believe that the real key to sucsess is working smart and hard.\"\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LF_v_xgVyBPq"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n",
            "\n",
            "### Instruction:\n",
            "Make a prediction about what will happen in the next paragraph.\n",
            "\n",
            "### Input:\n",
            "Mary had been living in the small town for many years and had never seen anything like what was coming.\n",
            "\n",
            "### Response:\n",
            "She will be surprised by the event.\n"
          ]
        }
      ],
      "source": [
        "generate_response(\n",
        "    instruction=\"Make a prediction about what will happen in the next paragraph.\",\n",
        "    input_text=\"Mary had been living in the small town for many years and had never seen anything like what was coming.\",\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I5wj7ZsDhPHH"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n",
            "\n",
            "### Instruction:\n",
            "Identify a suitable <verb> in the following sentence.\n",
            "\n",
            "### Input:\n",
            "The cat <verb> in the garden.\n",
            "\n",
            "### Response:\n",
            "played\n"
          ]
        }
      ],
      "source": [
        "generate_response(\n",
        "    instruction=\"Identify a suitable <verb> in the following sentence.\",\n",
        "    input_text=\"The cat <verb> in the garden.\",\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6hZs04c70H2R"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n",
            "\n",
            "### Instruction:\n",
            "Explain why the quote is appropriate or not for a yoga class.\n",
            "\n",
            "### Input:\n",
            "Don't quit. Suffer now and live the rest of your life as a champion.\n",
            "\n",
            "### Response:\n",
            "This quote is not appropriate for a yoga class because it promotes a competitive mindset and ignores the importance of self-compassion and acceptance.\n"
          ]
        }
      ],
      "source": [
        "generate_response(\n",
        "    instruction=\"Explain why the quote is appropriate or not for a yoga class.\",\n",
        "    input_text=\"Don't quit. Suffer now and live the rest of your life as a champion.\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p3jL-Z8CtqgP"
      },
      "source": [
        "## Push the Model to the Hugging Face Hub\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aM84Ti3r02Tz"
      },
      "source": [
        "Optionally, you can upload the fine-tuned model to the Hugging Face Hub to store and share it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HIDWBva0_SX4"
      },
      "outputs": [],
      "source": [
        "# Note: Pushing requires a Hugging Face access token with \"write\" permission.\n",
        "#       You can create or check your tokens here:\n",
        "#       https://huggingface.co/settings/tokens\n",
        "# Uncomment and run these lines to publish the model and its tokenizer to the Hugging Face Hub\n",
        "# model.push_to_hub(\"my-gemma-finetuned-model\")\n",
        "# tokenizer.push_to_hub(\"my-gemma-finetuned-model\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F3jKeiM_HJWJ"
      },
      "source": [
        "In this tutorial, you learned how to fine-tune a Gemma model using JORA and convert it to the Hugging Face format for inference. By using JAX's JIT compilation and tensor sharding, JORA manages memory efficiently and speeds up fine-tuning, which lowers the memory required for long RAG prompts."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[Gemma_2]Finetune_with_JORA.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
